import os
import string
import json

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

#sns.set()
#------------------------------------------------------------------------------
def count(labels, preds, val_label, val_pred):
    """Return how many samples have label ``val_label`` and prediction ``val_pred``.

    ``preds`` is indexed by each label's position in ``labels``; it is only
    read at positions whose label equals ``val_label``.
    """
    return sum(
        1
        for position, label in enumerate(labels)
        if label == val_label and preds[position] == val_pred
    )

def confusion_matrix(labels, preds, texts, order = None, return_df = True, force_int = True):
    """Build a confusion matrix of ``labels`` (rows) against ``preds`` (columns).

    Parameters
    ----------
    labels : sequence
        Ground-truth label of each sample.
    preds : sequence
        Predicted label of each sample; must have the same length as ``labels``.
    texts : mapping
        Maps each class value in ``order`` to the display name used for the
        DataFrame index/columns. Only consulted when ``return_df`` is true.
    order : sequence, optional
        Class values in row/column order. Defaults to the sorted unique values
        of ``labels``. Samples whose label or prediction is not in ``order``
        are not counted.
    return_df : bool
        If true, also return the matrix as a labelled ``pandas.DataFrame``.
    force_int : bool
        If true the matrix has dtype ``int64``, otherwise ``float64``.

    Returns
    -------
    C : numpy.ndarray
        ``(len(order), len(order))`` array; ``C[i, j]`` counts samples whose
        label is ``order[i]`` and whose prediction is ``order[j]``.
    df : pandas.DataFrame
        Returned only when ``return_df`` is true: ``C`` with rows and columns
        named via ``texts``.

    Raises
    ------
    ValueError
        If ``labels`` and ``preds`` have different lengths.
    """
    if len(labels) != len(preds):
        raise ValueError(
            "labels and preds must have the same length (got %d and %d)"
            % (len(labels), len(preds))
        )
    if order is None:
        order = np.unique(labels)

    # Size the matrix from ``order`` rather than from the labels actually
    # present, so a class missing from this particular ``labels`` still gets
    # a (zero) row/column instead of causing an IndexError.
    n = len(order)
    C = np.zeros((n, n), dtype=np.int64 if force_int else np.float64)

    # Map every class value to its row/column once, then tally in one pass.
    position = {value: i for i, value in enumerate(order)}
    for label, pred in zip(labels, preds):
        i = position.get(label)
        j = position.get(pred)
        if i is not None and j is not None:
            C[i, j] += 1

    if not return_df:
        return C

    index = [texts[m] for m in order]
    df = pd.DataFrame(data=C, index=index, columns=index)
    return C, df
    
def plot_confusion_matrix(df, figsize = (10, 10), cmap = 'Blues', annot=True, fmt='d', block = False):
    """Draw ``df`` as an annotated heatmap on a new figure and return the figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Square confusion matrix; rows are actual labels, columns predicted.
    figsize : (float, float)
        Figure size in inches passed to ``plt.figure`` (previously ignored).
    cmap : str
        Colormap name passed to ``seaborn.heatmap``.
    annot : bool
        Whether to write each cell's value inside the cell.
    fmt : str
        Annotation format string passed to ``seaborn.heatmap`` (``'d'`` needs
        integer data; use e.g. ``'.2f'`` for averaged matrices).
    block : bool
        If true, call ``plt.show()`` before returning.

    Returns
    -------
    matplotlib.figure.Figure
        The figure the heatmap was drawn on.
    """
    fig = plt.figure(figsize=figsize)
    # Draw on an explicit axes of this figure instead of pyplot's implicit
    # "current axes".
    ax = fig.add_subplot(1, 1, 1)
    sns.heatmap(df, cmap=cmap, annot=annot, fmt=fmt, ax=ax)
    ax.set_ylabel("Actual label")
    ax.set_xlabel("Predicted label")
    if block:
        plt.show()
    return fig
#------------------------------------------------------------------------------

# Repository root is the parent of the directory holding this script; the
# per-subject score files are read from, and figures written to, "results".
repository_base = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
base = os.path.join(repository_base, "results")

# Study layout. n_class is the number of classification conditions, used as
# JSON keys "1".."n_class" further down; texts/order give the label values,
# their display names, and their row/column order in the matrices.
n_subjects = 11
n_class = 3
texts = {-100: 'Attended', -500: 'Nonattended'}
order = [-100, -500]

alphabet = list(string.ascii_uppercase)

# Subject identifiers sub-A, sub-B, ... (one letter per subject).
subjects = ["sub-%s" % alphabet[m] for m in range(n_subjects)]

#------------------------------------------------------------------------------
# Grand average: per-condition confusion matrices averaged across all subjects

# C[c] first collects one confusion matrix per subject for classification
# condition c (JSON key str(c)), then is replaced by their element-wise mean.
C = {c: [] for c in range(1, n_class + 1)}

for subject in subjects:
    # JSON is UTF-8 by specification; don't depend on the platform encoding.
    with open(os.path.join(base, "%s_classification_scores.json" % subject),
              'r', encoding='utf-8') as f:
        data = json.load(f)

    for c in range(1, n_class + 1):
        # assumes data[str(c)] has parallel 'labels' / 'preds' lists — confirm
        # against the code that writes these files.
        _C = confusion_matrix(data[str(c)]['labels'], data[str(c)]['preds'],
                              texts, order, return_df=False)
        C[c].append(_C)

# np.stack fails loudly if subjects' matrices differ in shape, instead of
# building a ragged array.
for m in range(1, n_class + 1):
    C[m] = np.mean(np.stack(C[m]), axis=0)

# Display names for matrix rows/columns, in the same order as the matrices.
index = [texts[m] for m in order]

# One heatmap per classification condition. Save through the returned figure
# explicitly (not pyplot's implicit current figure) and close it afterwards so
# figures do not pile up in pyplot's figure manager.
for m in range(1, n_class + 1):
    df = pd.DataFrame(data=C[m], index=index, columns=index)
    fig = plot_confusion_matrix(df, fmt='.2f')
    fig.savefig(os.path.join(base, "confusion_grand_%d.png" % m), dpi=300)
    plt.close(fig)